/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use core::hash::Hash;
use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::Range;

/// Output of [`generate_permutation`]: a permutation between two slices of items, plus
/// information about the items of the source and target slices that could not be matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermutationResult {
    /// For each matched element, taken in `target` order, its index in `source`.
    pub perm: Vec<usize>,
    /// Indices of `source` elements with no counterpart in `target`
    /// (ascending when produced by `generate_permutation`).
    pub unmatched_source_indices: Vec<usize>,
    /// Indices of `target` elements with no counterpart in `source`
    /// (ascending when produced by `generate_permutation`).
    pub unmatched_target_indices: Vec<usize>,
}

/// Result of [`generate_expanded_permutation`]. Like `PermutationResult`, except that `perm`
/// keeps one slot per `target` element. Target elements with no counterpart in `source` are
/// marked with `None` instead of being removed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpandedPermResult {
    /// `perm[i]` is `Some(j)` when `target[i]` was matched with `source[j]`, otherwise `None`.
    pub perm: Vec<Option<usize>>,
    /// Indices of `source` elements with no counterpart in `target`.
    pub unmatched_source_indices: Vec<usize>,
    /// Indices of `target` elements with no counterpart in `source` (the `None` slots).
    pub unmatched_target_indices: Vec<usize>,
}

impl ExpandedPermResult {
    /// Convert into a [`PermutationResult`] by dropping the `None` placeholders that mark
    /// unmatched target positions. Both unmatched index lists are carried over unchanged.
    fn collapse(self) -> PermutationResult {
        // `Option` is iterable, so flattening keeps only the matched source indices,
        // in their original relative order.
        let compact: Vec<usize> = self.perm.into_iter().flatten().collect();

        PermutationResult {
            unmatched_target_indices: self.unmatched_target_indices,
            unmatched_source_indices: self.unmatched_source_indices,
            perm: compact,
        }
    }
}
/// Match every element of `target` with an equal element of `source`, keeping `target`'s
/// layout: unmatched target positions are marked with `None` instead of being removed.
///
/// `perm[i]` is `Some(j)` when `target[i]` was matched with `source[j]`. It is `None` when
/// no unused equal element remains in `source`, and `i` is then also listed in
/// `unmatched_target_indices`.
///
/// Duplicate values are matched left to right: the first occurrence in `target` takes the
/// first unused occurrence in `source`. Source indices that were never matched are returned
/// in ascending order in `unmatched_source_indices`. `unmatched_target_indices` is ascending
/// as well.
pub fn generate_expanded_permutation<T>(source: &[T], target: &[T]) -> ExpandedPermResult
where
    T: Hash + Eq,
{
    // For every distinct source value, the indices where it occurs. They are stored in
    // *descending* order (by walking `source` backwards) so that `pop` hands out the
    // lowest unused index first.
    let mut available: HashMap<&T, Vec<usize>> = HashMap::new();
    for (index, value) in source.iter().enumerate().rev() {
        available.entry(value).or_default().push(index);
    }

    let mut perm = Vec::with_capacity(target.len());
    let mut unmatched_target_indices = Vec::new();
    for (index, value) in target.iter().enumerate() {
        let matched = available.get_mut(value).and_then(|indices| indices.pop());
        if matched.is_none() {
            unmatched_target_indices.push(index);
        }
        perm.push(matched);
    }

    // Whatever is left was never claimed by a target element. HashMap iteration order is
    // unspecified, so sort to make the result deterministic.
    let mut unmatched_source_indices: Vec<usize> = available
        .into_iter()
        .flat_map(|(_, indices)| indices)
        .collect();
    unmatched_source_indices.sort_unstable();

    ExpandedPermResult {
        perm,
        unmatched_source_indices,
        unmatched_target_indices,
    }
}

/// Generate a permutation that, applied to `source`'s elements, produces `target`.
///
/// `perm[k]` is the index in `source` of the element found at the `k`-th matched position
/// of `target`. Elements that exist in only one of the slices (after matching duplicates
/// one-to-one) are left out of `perm`. They are reported through `unmatched_source_indices`
/// and `unmatched_target_indices` instead. `perm` therefore has exactly one entry per
/// matched pair and is never longer than the shorter input.
pub fn generate_permutation<T>(source: &[T], target: &[T]) -> PermutationResult
where
    T: Hash + Eq,
{
    let expanded = generate_expanded_permutation(source, target);
    expanded.collapse()
}

/// Find the sub-range of `perm` that actually has to be sorted.
///
/// `perm` is expected to hold distinct indices no larger than `max_index`, as produced by
/// `generate_permutation`. Two parts of such a permutation are already in their final
/// sorted position:
///
/// * the longest prefix where `perm[i] == i`: these are the smallest possible values,
///   already in ascending order;
/// * the longest suffix reading `max_index, max_index - 1, ...` from the end: these are the
///   largest possible values, already in ascending order.
///
/// No swaps are needed outside the returned range when sorting. A fully sorted identity
/// permutation yields the empty range `perm.len()..perm.len()`.
fn find_sortable_range(perm: &[usize], max_index: usize) -> Range<usize> {
    let prefix_len = perm
        .iter()
        .enumerate()
        .take_while(|&(position, &value)| position == value)
        .count();

    // If the prefix already covers everything, there is nothing left to trim. Scanning for a
    // suffix would count the same elements a second time.
    let postfix_len = if prefix_len < perm.len() {
        perm.iter()
            .rev()
            .enumerate()
            // `checked_sub` ends the scan (instead of overflowing) when `perm` holds more
            // than `max_index + 1` entries, which only malformed input can do.
            .take_while(|&(from_end, &value)| max_index.checked_sub(from_end) == Some(value))
            .count()
    } else {
        0
    };

    prefix_len..perm.len() - postfix_len
}

/// Steps back and forth through the intermediate orderings recorded while bubble-sorting a
/// permutation of `source`.
///
/// The cursor counts how many recorded swaps are currently applied. A freshly built iterator
/// starts at `history_size`, where every swap is applied and the permutation is sorted. At
/// cursor `0` the permutation is back in its unsorted input order. Each step applies or
/// undoes exactly one adjacent swap.
#[derive(Debug)]
pub struct BubbleSortIter<'a, T> {
    /// Current permutation: indices into `source`.
    perm: Vec<usize>,
    /// Elements that `perm` indexes into.
    source: &'a [T],
    /// Adjacent swaps recorded by the sort, in the order they were performed.
    swap_history: Vec<(usize, usize)>,
    /// Cached `swap_history.len()`.
    history_size: usize,
    /// Number of recorded swaps currently applied to `perm` (always `<= history_size`).
    history_cursor: usize,
    /// Number of elements that appear in only one of the source/target slices.
    unused_value_count: usize,
}

impl<'a, T> BubbleSortIter<'a, T>
where
    T: Hash + Eq,
{
    /// Create an iterator positioned at the end of `swap_history`, i.e. with every recorded
    /// swap considered applied to `perm`.
    fn new(
        perm: Vec<usize>,
        source: &'a [T],
        swap_history: Vec<(usize, usize)>,
        unused_value_count: usize,
    ) -> Self {
        BubbleSortIter {
            perm,
            source,
            history_size: swap_history.len(),
            history_cursor: swap_history.len(),
            swap_history,
            unused_value_count,
        }
    }

    /// Edit distance between the source and target slices. This is the swap distance plus
    /// one for every element that appears in only one of the two slices.
    pub fn edit_distance(&self) -> usize {
        self.history_size + self.unused_value_count
    }

    /// Number of adjacent swaps needed to permute the source slice into the target slice.
    /// Elements that appear in only one of the two slices are not counted.
    pub fn swap_distance(&self) -> usize {
        self.history_size
    }

    /// Apply the swap at the cursor and return the resulting ordering. Returns `None` when
    /// every recorded swap is already applied.
    ///
    /// This is an inherent method rather than `Iterator::next`: the returned `Perm` borrows
    /// `self`, which the `Iterator` trait cannot express.
    pub fn next<'b>(&'b mut self) -> Option<Perm<'a, 'b, T>> {
        (self.history_cursor < self.history_size).then(|| {
            let swap = self.swap_history[self.history_cursor];
            self.perm.swap(swap.0, swap.1);
            self.history_cursor += 1;

            self.current()
        })
    }

    /// Apply up to `distance` swaps forward, stopping at the end of the history.
    pub fn skip(&mut self, distance: usize) {
        // Bound the loop so a huge distance (e.g. `usize::MAX`) doesn't spin on no-op calls.
        let remaining = self.history_size - self.history_cursor;
        for _ in 0..distance.min(remaining) {
            let _ = self.next();
        }
    }

    /// Undo the most recently applied swap and return the resulting ordering. Returns `None`
    /// when no swap is currently applied.
    pub fn next_back<'b>(&'b mut self) -> Option<Perm<'a, 'b, T>> {
        (self.history_cursor > 0).then(|| {
            self.history_cursor -= 1;
            let swap = self.swap_history[self.history_cursor];
            self.perm.swap(swap.0, swap.1);

            self.current()
        })
    }

    /// Undo up to `distance` swaps, stopping at the start of the history.
    pub fn skip_back(&mut self, distance: usize) {
        // Same bound as `skip`: at most `history_cursor` swaps can be undone.
        for _ in 0..distance.min(self.history_cursor) {
            let _ = self.next_back();
        }
    }

    /// The ordering at the current cursor position, without moving the cursor.
    pub fn current<'b>(&'b self) -> Perm<'a, 'b, T> {
        Perm::new(self)
    }

    /// Move the cursor to the middle of the history (`history_size / 2`, rounded down) and
    /// return the ordering there.
    pub fn midpoint<'b>(&'b mut self) -> Perm<'a, 'b, T> {
        self.seek(self.history_size / 2);
        self.current()
    }

    /// Move the cursor `interpolate_percentage` percent of the way through the history and
    /// return the ordering there. The position is rounded down.
    ///
    /// `0` is the unsorted ordering and `100` is the fully sorted one. For iterators built by
    /// `iterable_bubble_sort`, these are the target and source orderings respectively.
    /// Values above 100 behave like 100.
    pub fn interpolate<'b>(&'b mut self, interpolate_percentage: u8) -> Perm<'a, 'b, T> {
        let percentage = usize::from(interpolate_percentage.min(100));
        self.seek(self.history_size * percentage / 100);
        self.current()
    }

    /// Total number of recorded swaps (same value as `swap_distance`).
    pub fn history_size(&self) -> usize {
        self.history_size
    }

    /// Move the cursor to `position` (clamped to the end of the history), applying or
    /// undoing swaps as needed.
    fn seek(&mut self, position: usize) {
        let position = position.min(self.history_size);
        if position > self.history_cursor {
            self.skip(position - self.history_cursor);
        } else {
            self.skip_back(self.history_cursor - position);
        }
    }
}

/// Iterator over the elements of `source`, in the order given by one snapshot of a
/// [`BubbleSortIter`]'s permutation. It borrows the parent's permutation, so the parent
/// cannot be moved to another position while a `Perm` is alive.
pub struct Perm<'a, 'b, T> {
    /// Snapshot of the parent's permutation: indices into `source`.
    perm: &'b [usize],
    /// Elements being reordered.
    source: &'a [T],
    /// Position in `perm` of the next index to yield.
    i: usize,
}

impl<'a, 'b, T> Perm<'a, 'b, T> {
    fn new(parent: &'b BubbleSortIter<'a, T>) -> Self {
        Perm {
            perm: &parent.perm,
            source: parent.source,
            i: 0,
        }
    }
}

impl<'a, 'b, T> Iterator for Perm<'a, 'b, T>
where
    T: Hash + Eq,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        (self.i < self.perm.len()).then(|| {
            let result = &self.source[self.perm[self.i]];
            self.i += 1;
            result
        })
    }
}

/// Bubble-sort `perm` in place and return the record of every adjacent swap performed.
///
/// Each recorded pair `(a, a + 1)` indexes the *whole* `perm` slice. Pairs appear in the
/// order the swaps were applied, so replaying them in reverse from the sorted result
/// restores the original `perm`. Only the sub-range reported by `find_sortable_range` is
/// scanned, because elements outside it are already in their final position. `source` is
/// consulted only for its length, which bounds the largest index `perm` can contain.
fn bubble_sort_history<T>(perm: &mut [usize], source: &[T]) -> Vec<(usize, usize)> {
    let mut history = vec![];

    // Zero or one element never needs a swap. This also covers an empty `source`, whose
    // only valid permutation is empty.
    if perm.len() < 2 {
        return history;
    }

    // `saturating_sub` keeps a malformed caller-supplied `perm` (non-empty while `source`
    // is empty) from underflowing.
    let max_index = source.len().saturating_sub(1);
    let sortable_range = find_sortable_range(perm, max_index);
    let index_offset = sortable_range.start;
    let sortable_slice = &mut perm[sortable_range];
    let size = sortable_slice.len();

    if size < 2 {
        return history;
    }

    for pass in 0..(size - 1) {
        let mut swapped = false;
        for j in 0..(size - pass - 1) {
            if sortable_slice[j] > sortable_slice[j + 1] {
                // Adjacent pair out of order: swap it, and record the swap in whole-`perm`
                // coordinates so it can be replayed against the full permutation.
                sortable_slice.swap(j, j + 1);
                history.push((index_offset + j, index_offset + j + 1));
                swapped = true;
            }
        }

        // A pass without any swap means the slice is sorted; later passes would be no-ops.
        if !swapped {
            break;
        }
    }

    history
}

/// Number of adjacent swaps bubble sort needs to move between the order that the matched
/// elements have in `source` and the order they have in `target`. Elements present in only
/// one slice are ignored; `BubbleSortIter::edit_distance` also counts them.
///
/// NOTE(review): these bounds are stricter than `iterable_bubble_sort` requires
/// (`Hash + Eq`). Relaxing them would be backward compatible.
pub fn bubble_sort_distance<T>(source: &[T], target: &[T]) -> usize
where
    T: PartialEq + Hash + Eq + Debug + Clone + Copy + 'static,
{
    let walker = iterable_bubble_sort(source, target);
    walker.swap_distance()
}

/// Bubble-sort `permutation_result.perm` and wrap the recorded swaps in a
/// [`BubbleSortIter`] over `source`.
///
/// The iterator starts with every swap applied (the sorted permutation). It counts the
/// unmatched indices of both sides towards `edit_distance`.
pub fn iterable_bubble_sort_from_perm<T>(
    source: &[T],
    permutation_result: PermutationResult,
) -> BubbleSortIter<T>
where
    T: Hash + Eq,
{
    let unused_values = permutation_result.unmatched_source_indices.len()
        + permutation_result.unmatched_target_indices.len();

    let mut perm = permutation_result.perm;
    let history = bubble_sort_history(&mut perm, source);

    BubbleSortIter::new(perm, source, history, unused_values)
}

/// Build a [`BubbleSortIter`] that walks between `source` and `target` one adjacent swap at
/// a time. Elements present in only one of the slices are left out of the walk; see
/// `generate_permutation`.
pub fn iterable_bubble_sort<'a, T>(source: &'a [T], target: &'a [T]) -> BubbleSortIter<'a, T>
where
    T: Hash + Eq,
{
    let matching = generate_permutation(source, target);
    iterable_bubble_sort_from_perm(source, matching)
}

#[cfg(test)]
mod tests {
    use std::str::from_utf8;

    use super::*;

    #[test]
    fn test_get_permutations() {
        assert_eq!(
            generate_permutation(b"abaca", b"aaacb"),
            PermutationResult {
                perm: vec![0, 2, 4, 3, 1],
                unmatched_source_indices: vec![],
                unmatched_target_indices: vec![],
            }
        );
    }

    #[test]
    fn test_get_permutations_with_mismatch() {
        assert_eq!(
            generate_permutation(b"abacaF", b"aaaacb"),
            PermutationResult {
                perm: vec![0, 2, 4, 3, 1],
                unmatched_source_indices: vec![5],
                unmatched_target_indices: vec![3],
            }
        );
    }

    #[test]
    fn test_get_permutations_with_lots_of_mismatches() {
        // Once where mismatches are the only differences
        assert_eq!(
            generate_permutation(b"AabcefGHijmNNp", b"abcDefijKLmOpQRST"),
            PermutationResult {
                perm: vec![1, 2, 3, 4, 5, 8, 9, 10, 13],
                unmatched_source_indices: vec![0, 6, 7, 11, 12],
                unmatched_target_indices: vec![3, 8, 9, 11, 13, 14, 15, 16],
            }
        );

        // And again with some shuffling on the source. The unmatched target
        // indices shouldn't change
        assert_eq!(
            generate_permutation(b"fceaiGHNpbAjmN", b"abcDefijKLmOpQRST"),
            PermutationResult {
                perm: vec![3, 9, 1, 2, 0, 4, 11, 12, 8],
                unmatched_source_indices: vec![5, 6, 7, 10, 13],
                unmatched_target_indices: vec![3, 8, 9, 11, 13, 14, 15, 16],
            }
        );
    }

    #[test]
    fn test_identical_and_empty_inputs() {
        assert_eq!(bubble_sort_distance(b"abc", b"abc"), 0);
        assert_eq!(
            generate_permutation::<u8>(&[], &[]),
            PermutationResult {
                perm: vec![],
                unmatched_source_indices: vec![],
                unmatched_target_indices: vec![],
            }
        );
    }

    /// Checks that the part of the full permutation selected by `find_sortable_range`
    /// equals the expected elements.
    macro_rules! assert_sortable_range_eq {
        ([$($full:expr),*], [$($expected_sortable_range:expr),*]; where max_index = $max_index:expr) => {
            let full_perm : &[usize] = &[$($full),*];
            let range = find_sortable_range(full_perm, $max_index);
            let actual_sortable_range = &full_perm[range];
            let expected_sortable_range: &[usize] = &[$($expected_sortable_range),*];
            assert_eq!(actual_sortable_range, expected_sortable_range);
        };
    }

    #[test]
    fn test_find_sortable_range() {
        assert_sortable_range_eq!([0, 1, 2, 3, 4, 5, 6], []; where max_index = 6);
        assert_sortable_range_eq!([0, 1, 2, 4, 3, 5, 6], [4, 3]; where max_index = 6);
        assert_sortable_range_eq!([0, 1, 2, 4, 3, 5, 6], [4, 3, 5, 6]; where max_index = 7);
        assert_sortable_range_eq!([1, 2, 4, 3, 5, 6], [1, 2, 4, 3]; where max_index = 6);
    }

    #[test]
    fn test_bubble_sort_distance() {
        assert_eq!(bubble_sort_distance(b"abaca", b"aaacb"), 4);
    }

    macro_rules! assert_perm_eq {
        ($expected:expr, $perm:expr) => {
            assert_eq!(
                Ok($expected),
                from_utf8(&$perm.cloned().collect::<Vec<u8>>())
            );
        };
    }

    #[test]
    fn test_bubble_sort_iterator() {
        let mut it = iterable_bubble_sort(b"abaca", b"aaacb");

        // walk backwards through permutations

        assert_perm_eq!("abaca", it.current());
        assert_perm_eq!("aabca", it.next_back().expect("Should be 1 of 4"));
        assert_perm_eq!("aacba", it.next_back().expect("Should be 2 of 4"));
        assert_perm_eq!("aacab", it.next_back().expect("Should be 3 of 4"));
        assert_perm_eq!("aaacb", it.next_back().expect("Should be 4 of 4"));

        // walk forwards through permutations

        assert_perm_eq!("aaacb", it.current());
        assert_perm_eq!("aacab", it.next().expect("Should be 1 of 4"));
        assert_perm_eq!("aacba", it.next().expect("Should be 2 of 4"));
        assert_perm_eq!("aabca", it.next().expect("Should be 3 of 4"));
        assert_perm_eq!("abaca", it.next().expect("Should be 4 of 4"));
    }

    #[test]
    fn test_bubble_sort_distance_mismatch() {
        assert_eq!(bubble_sort_distance(b"afbaca", b"aadacb"), 4);
    }

    #[test]
    fn test_bubble_sort_iterator_mismatch() {
        let mut it = iterable_bubble_sort(b"afbaca", b"aadacb");

        assert_eq!(it.edit_distance(), it.swap_distance() + 2);

        // walk backwards through permutations

        assert_perm_eq!("abaca", it.current());
        assert_perm_eq!("aabca", it.next_back().expect("Should be 1 of 4"));
        assert_perm_eq!("aacba", it.next_back().expect("Should be 2 of 4"));
        assert_perm_eq!("aacab", it.next_back().expect("Should be 3 of 4"));
        assert_perm_eq!("aaacb", it.next_back().expect("Should be 4 of 4"));

        // walk forwards through permutations

        assert_perm_eq!("aaacb", it.current());
        assert_perm_eq!("aacab", it.next().expect("Should be 1 of 4"));
        assert_perm_eq!("aacba", it.next().expect("Should be 2 of 4"));
        assert_perm_eq!("aabca", it.next().expect("Should be 3 of 4"));
        assert_perm_eq!("abaca", it.next().expect("Should be 4 of 4"));
    }

    #[test]
    fn test_midpoint() {
        let mut it = iterable_bubble_sort(b"baaaa", b"aaaab");

        assert_perm_eq!("aabaa", it.midpoint());
    }

    #[test]
    fn test_interpolate() {
        let mut it = iterable_bubble_sort(b"abaca", b"aaacb");

        assert_perm_eq!("aaacb", it.interpolate(0));
        assert_perm_eq!("aacba", it.interpolate(50));
        assert_perm_eq!("abaca", it.interpolate(100));
    }

    /// This macro simplifies checking the correctness of `generate_expanded_permutation`.
    /// You pass in a source and target string, and it compares the produced expanded
    /// permutation against the expected string. Unmatched characters in the target are
    /// mapped to `_`.
    macro_rules! assert_expanded_target_eq {
        ($expected:expr, expand_target($src:expr, $target:expr)) => {
            assert_eq!(
                Ok($expected),
                from_utf8(
                    &generate_expanded_permutation($src, $target)
                        .perm
                        .into_iter()
                        .map(|i_opt| i_opt.map(|i| $src[i]).unwrap_or(b'_'))
                        .collect::<Vec<u8>>()
                )
            );
        };
    }

    #[test]
    fn test_expanded_perm() {
        assert_expanded_target_eq!("a_", expand_target(b"a", b"ab"));
        assert_expanded_target_eq!("aaacb", expand_target(b"abaca", b"aaacb"));
        assert_expanded_target_eq!("_aa_ac_b", expand_target(b"aVbaFGcaER", b"PaaAacTb"));
    }
}
